module ModuleOptTransition

    implicit none

    private
    public :: welfare_trans

    contains

!> Policy objective that accounts for the transition between steady states.
!!
!! Steps:
!!   1. Read the initial equilibrium from Input/eqm_results.txt and solve the
!!      initial steady state (welfare_eqm mode 1).
!!   2. Solve the final steady state under the candidate tax parameters
!!      PARAM_IN_VEC (welfare_eqm mode 2).
!!   3. If that final steady state converged, solve the transition path
!!      (Compute_JACOB, then Compute_QNEWTON) starting from the final-state prices.
!!
!! @param param_in_vec  Candidate tax parameters for the final steady state.
!!                      They use the same layout as the initial-state vector
!!                      built below (gamma, lambda, rt), since both are passed
!!                      to welfare_eqm.
!! @return  -sum(mu_TR(:,1)*V_TR(:,1)) when both the final steady state and
!!          the transition converge. Otherwise the penalty value 10000d0.
!!          NOTE(review): the negation and the large penalty suggest the caller
!!          minimizes this value. Confirm against the optimizer driver.
!!
!! Side effects: this is not a pure function. It assigns many module variables
!! made available by the use statements (norm_version, r, mt, wge, omegat,
!! the *_init, *_lr and *_TR variables, and whatever welfare_eqm and the
!! transition solvers set). It prints progress to standard output. It stops
!! the program if the input file cannot be opened or read.
function welfare_trans(param_in_vec)

    use Globals
    use ModuleSolveEqm
    use ModuleQTRANSITION

    implicit none

    ! Argument and result
    real(8), intent(in) :: param_in_vec(3)      ! candidate tax parameters (gamma, lambda, rt), final steady state
    real(8) :: welfare_trans                    ! welfare taking into account transition (penalty on failure)

    ! Locals
    real(8) :: welfare                          ! steady-state welfare returned by welfare_eqm (not used further here)
    real(8) :: param_vec_init(3)                ! parameters, initial steady state (gamma, lambda, rt)
    real(8) :: param_vec_lr(3)                  ! parameters, final (long-run) steady state (gamma, lambda, rt)
    real(8) :: eqm_results(5)                   ! raw contents of Input/eqm_results.txt
    integer :: wt_ios                           ! iostat of the open/read below

    ! Named constants (formerly inline literals)
    ! NOTE(review): unit 11 must not be open elsewhere while this function runs.
    integer, parameter :: wt_unit = 11          ! I/O unit for the input file
    real(8), parameter :: wt_penalty = 10000d0  ! objective value when a solve does not converge
    real(8), parameter :: wt_tol_ss(4) = (/ 1d-5, 1d-5, 5d-10, 1d-11 /)  ! bounds on error_vec(1:4)
    real(8), parameter :: wt_tol_tr = 1d-5      ! bound on max|asset_TR| and max|BC_TR|

    ! Read the initial equilibrium. Values are stored as r, lambda, gamma, mt, rt
    ! (see the assignments below). status='old' reports a missing file instead
    ! of creating an empty one and then failing on the read.
    open(wt_unit, file = 'Input/eqm_results.txt', status = 'old', action = 'read', iostat = wt_ios)
    if (wt_ios /= 0) then
        print *, 'welfare_trans: cannot open Input/eqm_results.txt, iostat = ', wt_ios
        stop 1
    endif
    read(wt_unit, *, iostat = wt_ios) eqm_results
    close(wt_unit)
    if (wt_ios /= 0) then
        print *, 'welfare_trans: cannot read 5 values from Input/eqm_results.txt, iostat = ', wt_ios
        stop 1
    endif

    ! Normalization choice: 1) mean income; 2) median income; 3) mean income of the third quintile
    norm_version = 1

    ! Tax function parameters of the initial steady state
    param_vec_init(1) = eqm_results(3)     ! gamma
    param_vec_init(2) = eqm_results(2)     ! lambda
    param_vec_init(3) = eqm_results(5)     ! rt

    print *, 'gamma, lambda, rt = ', param_vec_init

    mt = eqm_results(4)                    ! mt
    r  = eqm_results(1)                    ! interest rate

    ! wge = alpha * ((1-alpha)/(r+delta))**((1-alpha)/alpha)
    ! NOTE(review): presumably the wage implied by r. Confirm.
    wge = (1.0D0-alpha)/(r+delta)
    wge = alpha*(wge**((1D0-alpha)/alpha))

    ! Solve for the initial steady state
    omegat  = 0d0
    welfare = welfare_eqm(param_vec_init,1)

    ! Keep the initial steady-state objects
    r_init          = r
    mt_init         = mt
    error_vec_init  = error_vec
    Kagg_init       = Kagg

    ! Solve for the final (long-run) steady state under the candidate parameters
    param_vec_lr(1) = param_in_vec(1)
    param_vec_lr(2) = param_in_vec(2)
    param_vec_lr(3) = param_in_vec(3)
    welfare = welfare_eqm(param_vec_lr,2)

    ! Keep the final steady-state objects
    r_lr          = r
    lambda_lr     = lambda
    mt_lr         = mt
    error_vec_lr  = error_vec

    ! error_vec still holds the final steady state's errors here. The transition
    ! is attempted only if every component is below its tolerance.
    if (all(error_vec(1:4) < wt_tol_ss)) then

        ! Initialize the transition from the final steady-state prices
        r_TR      = r_lr
        lambda_TR = lambda_lr
        mt_TR     = mt_lr
        wge_TR = (1.0D0-alpha)/(r_TR+delta)
        wge_TR = alpha*(wge_TR**((1d0-alpha)/alpha))

        call Compute_JACOB
        call Compute_QNEWTON

        ! Accept the path only if the asset-market and budget residuals are small
        if (maxval(abs(asset_TR)) < wt_tol_tr .AND. maxval(abs(BC_TR)) < wt_tol_tr) then
            welfare_trans = -sum(mu_TR(:,1)*V_TR(:,1))
        else
            welfare_trans = wt_penalty
        endif

    else

        welfare_trans = wt_penalty

    endif

    print *, 'welfare_trans = ', welfare_trans

end function welfare_trans

end module ModuleOptTransition